/// Cross product of two 2D tensors.
#include "benchmark_config.h"
#include "gpu_timer.h"
#include <chrono>
#include <cuda.h>
#include <math.h>
#include <random>
#include <stdio.h>
#include <stdlib.h>
#include <type_traits>

// Reference
// https://pytorch.org/docs/stable/generated/torch.linalg.cross.html#torch.linalg.cross

template <typename T, typename StrideType>
__global__ void cross_kernel(int numel, T *out, const T *x1, const T *x2,
                             StrideType ostride, StrideType x1stride,
                             StrideType x2stride) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel;
       i += blockDim.x * gridDim.x) {

    auto *out_row = out + 3 * i;
    const auto *x1_row = x1 + 3 * i;
    const auto *x2_row = x2 + 3 * i;

    const T val0 = (x1_row[1 * x1stride] * x2_row[2 * x2stride] -
                    x1_row[2 * x1stride] * x2_row[1 * x2stride]);

    const T val1 = (x1_row[2 * x1stride] * x2_row[0 * x2stride] -
                    x1_row[0 * x1stride] * x2_row[2 * x2stride]);

    const T val2 = (x1_row[0 * x1stride] * x2_row[1 * x2stride] -
                    x1_row[1 * x1stride] * x2_row[0 * x2stride]);

    out_row[0 * ostride] = val0;
    out_row[1 * ostride] = val1;
    out_row[2 * ostride] = val2;
  }
}

template <typename T, typename StrideType>
__global__ void cross2_kernel(int numel, T *out, const T *x1, const T *x2,
                              StrideType ostride, StrideType x1stride,
                              StrideType x2stride) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel;
       i += blockDim.x * gridDim.x) {

    auto *out_row = out + 3 * i;
    const auto *x1_row = x1 + 3 * i;
    const auto *x2_row = x2 + 3 * i;

    const T x1_c0 = x1_row[0 * x1stride];
    const T x1_c1 = x1_row[1 * x1stride];
    const T x1_c2 = x1_row[2 * x1stride];
    const T x2_c0 = x2_row[0 * x2stride];
    const T x2_c1 = x2_row[1 * x2stride];
    const T x2_c2 = x2_row[2 * x2stride];

    const T val0 = x1_c1 * x2_c2 - x1_c2 * x2_c1;

    const T val1 = x1_c2 * x2_c0 - x1_c0 * x2_c2;

    const T val2 = x1_c0 * x2_c1 - x1_c1 * x2_c0;

    out_row[0 * ostride] = val0;
    out_row[1 * ostride] = val1;
    out_row[2 * ostride] = val2;
  }
}

template <typename T>
__global__ void cross3_kernel(int numel, T *out, const T *x1, const T *x2) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel;
       i += blockDim.x * gridDim.x) {

    auto *out_row = out + 3 * i;
    const auto *x1_row = x1 + 3 * i;
    const auto *x2_row = x2 + 3 * i;

    const T x1_c0 = x1_row[0];
    const T x1_c1 = x1_row[1];
    const T x1_c2 = x1_row[2];
    const T x2_c0 = x2_row[0];
    const T x2_c1 = x2_row[1];
    const T x2_c2 = x2_row[2];

    const T val0 = x1_c1 * x2_c2 - x1_c2 * x2_c1;

    const T val1 = x1_c2 * x2_c0 - x1_c0 * x2_c2;

    const T val2 = x1_c0 * x2_c1 - x1_c1 * x2_c0;

    out_row[0] = val0;
    out_row[1] = val1;
    out_row[2] = val2;
  }
}

template <typename T> void cal(int numel, T *out, const T *x1, const T *x2) {
  for (int i = 0; i < numel; ++i) {
    auto *out_row = out + 3 * i;
    const auto *x1_row = x1 + 3 * i;
    const auto *x2_row = x2 + 3 * i;
    const T x1_c0 = x1_row[0];
    const T x1_c1 = x1_row[1];
    const T x1_c2 = x1_row[2];
    const T x2_c0 = x2_row[0];
    const T x2_c1 = x2_row[1];
    const T x2_c2 = x2_row[2];

    const T val0 = x1_c1 * x2_c2 - x1_c2 * x2_c1;

    const T val1 = x1_c2 * x2_c0 - x1_c0 * x2_c2;

    const T val2 = x1_c0 * x2_c1 - x1_c1 * x2_c0;

    out_row[0] = val0;
    out_row[1] = val1;
    out_row[2] = val2;
  }
}

template <typename T> void eval(const int nrows, const int repeat) {
  const int num_elems = nrows * 3;
  const int size_bytes = num_elems * sizeof(T);

  T *a, *b, *o, *o2, *o3;
  T *o_cpu;
  a = (T *)malloc(size_bytes);
  b = (T *)malloc(size_bytes);
  o = (T *)malloc(size_bytes);
  o2 = (T *)malloc(size_bytes);
  o3 = (T *)malloc(size_bytes);
  o_cpu = (T *)malloc(size_bytes);

  std::default_random_engine g(123);
  std::uniform_real_distribution<T> distr(-2.f, 2.f);
  for (int i = 0; i < num_elems; i++) {
    a[i] = distr(g);
    b[i] = distr(g);
  }

  T *d_a, *d_b, *d_o;
  cudaMalloc((void **)&d_o, size_bytes);

  cudaMalloc((void **)&d_a, size_bytes);
  cudaMemcpy(d_a, a, size_bytes, cudaMemcpyHostToDevice);

  cudaMalloc((void **)&d_b, size_bytes);
  cudaMemcpy(d_b, b, size_bytes, cudaMemcpyHostToDevice);

  dim3 grid((nrows + 255) / 256);
  dim3 block(256);

  cudaDeviceSynchronize();

  cross_kernel<<<grid, block>>>(nrows, d_o, d_a, d_b, 1, 1, 1);

  cudaDeviceSynchronize();

  cudaMemcpy(o, d_o, size_bytes, cudaMemcpyDeviceToHost);

  cross2_kernel<<<grid, block>>>(nrows, d_o, d_a, d_b, 1, 1, 1);

  cudaDeviceSynchronize();

  cudaMemcpy(o2, d_o, size_bytes, cudaMemcpyDeviceToHost);

  cal(nrows, o_cpu, a, b);

  cudaDeviceSynchronize();

  // warm up kernel
  for(int i= 0; i < WARMUPNUM; i++){
    cross3_kernel<<<grid, block>>>(nrows, d_o, d_a, d_b);
  }
  cudaDeviceSynchronize();

  GpuTimer timer;
  double total_time = 0;
  for (int i = 0; i < repeat; i++) {
    timer.start();
    cross3_kernel<<<grid, block>>>(nrows, d_o, d_a, d_b);
    timer.stop_and_wait();
    total_time += timer.duration(1);
  }

  cudaDeviceSynchronize();

  double avg_time = total_time / repeat;
  double performance = size_bytes * 1000. / 1024. / 1024. / 1024. / avg_time;

  if (std::is_same<T, float>::value) {
    printf("Average execution time of cross_fp32 kernels: %f ms\n", avg_time);
    printf("Average cross_fp32 kernel performance: %f GB/s\n", performance);
  }

  if (std::is_same<T, double>::value) {
    printf("Average cross_fp64 kernel performance: %f GB/s\n", performance);
  }

  cudaMemcpy(o3, d_o, size_bytes, cudaMemcpyDeviceToHost);

  bool ok = true;
  for (int i = 0; i < num_elems; i++) {
    if (fabs(o[i] - o2[i]) > 1e-3 || fabs(o[i] - o3[i]) > 1e-3 ||
        fabs(o3[i] - o_cpu[i]) > 1e-3) {
      ok = false;
      break;
    }
  }
  printf("%s\n", ok ? "PASS" : "FAIL");

  cudaFree(d_a);
  cudaFree(d_b);
  cudaFree(d_o);

  free(a);
  free(b);
  free(o);
  free(o2);
  free(o3);
}

int main(int argc, char *argv[]) {
  if (argc != 4) {
    printf("Usage: %s <number of rows in a 2D tensor> <repeat> <Data type: "
           "0(FP32), 1(FP64)>\n",
           argv[0]);
    return 1;
  }
  const int nrows = atoi(argv[1]);
  const int repeat = atoi(argv[2]);

  const int type = atoi(argv[3]);

  if (type == 0) {
    printf("=========== Data type is FP32 ==========\n");
    eval<float>(nrows, repeat);
  } else {

    printf("=========== Data type is FP64 ==========\n");
    eval<double>(nrows, repeat);
  }

  return 0;
}
